{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook is intended to demonstrate how select registration, segmentation, and image mathematical methods of ITKTubeTK can be combined to perform multi-channel brain extraction (aka. skull stripping for patient data containing multiple MRI sequences).\n",
    "\n",
    "There are many other (probably more effective) brain extraction methods available as open-source software such as BET and BET2 in the FSL package (albeit such methods are only for single channel data).   If you need to perform brain extraction for a large collection of scans that do not contain major pathologies, please use one of those packages.   This notebook is meant to show off the capabilities of specific ITKTubeTK methods, not to demonstration how to \"solve\" brain extraction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itk\n",
    "from itk import TubeTK as ttk\n",
    "\n",
    "from itkwidgets import view\n",
    "\n",
    "import numpy as np\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "ImageType = itk.Image[itk.F, 3]\n",
    "\n",
    "InputBaseName = \"../Data/CTP-MinMax/max3\"\n",
    "\n",
    "filename = InputBaseName + \".mha\"\n",
    "im1iso = itk.imread(filename, itk.F)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 8\n",
    "readerList = [\"003\", \"010\", \"026\", \"034\", \"045\", \"056\", \"063\", \"071\"]\n",
    "\n",
    "imBase = []\n",
    "imBaseB = []\n",
    "for i in range(0,N):\n",
    "    name = \"../Data/MRI-Normals/Normal\"+readerList[i]+\"-FLASH.mha\"\n",
    "    nameB = \"../Data/MRI-Normals/Normal\"+readerList[i]+\"-FLASH-Brain.mha\"\n",
    "    imBaseTmp = itk.imread(name, itk.F)\n",
    "    imBaseBTmp = itk.imread(nameB, itk.F)\n",
    "    imBase.append(imBaseTmp)\n",
    "    imBaseB.append(imBaseBTmp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00a2aaa203384f1c955d2ef4e35bf32f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "view(im1iso)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2e30260b8b6d4970bde69d81ce6c3610",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "imMath = ttk.ImageMath.New(Input=im1iso)\n",
    "#imMath.Threshold(-4000,-500,1,0)\n",
    "#headMask = imMath.GetOutput()\n",
    "imMath.SetInput(im1iso)\n",
    "#imMath.IntensityWindow(0,1000,1000,0)\n",
    "#imMath.ReplaceValuesOutsideMaskRange(headMask,-0.5,0.5,-500)\n",
    "imMath.Blur(1)\n",
    "imMath.NormalizeMeanStdDev()\n",
    "imMath.IntensityWindow(-5,5,-500,500)\n",
    "im1isoBlur = imMath.GetOutput()\n",
    "view(im1isoBlur)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "***  0 : 12% : 35s  ***\n",
      "***  1 : 25% : 33s  ***\n",
      "***  2 : 38% : 30s  ***\n",
      "***  3 : 50% : 33s  ***\n",
      "***  4 : 62% : 37s  ***\n",
      "***  5 : 75% : 32s  ***\n",
      "***  6 : 88% : 34s  ***\n",
      "***  7 : 100% : 32s  ***\n"
     ]
    }
   ],
   "source": [
    "RegisterImagesType = ttk.RegisterImages[ImageType]\n",
    "regB = []\n",
    "regBB = []\n",
    "for i in range(0,N):\n",
    "    start = time.time()\n",
    "    \n",
    "    imMath.SetInput(imBase[i])\n",
    "    imMath.Blur(1)\n",
    "    imMath.NormalizeMeanStdDev()\n",
    "    imMath.IntensityWindow(-5,5,-500,500)\n",
    "    imBaseBlur = imMath.GetOutput()\n",
    "    \n",
    "    regBTo1 = RegisterImagesType.New(FixedImage=imBaseBlur, MovingImage=im1isoBlur)\n",
    "    \n",
    "    regBTo1.SetRigidMaxIterations(3000)\n",
    "    regBTo1.SetAffineMaxIterations(3000)\n",
    "    \n",
    "    regBTo1.SetExpectedRotationMagnitude(0.2)\n",
    "    regBTo1.SetExpectedScaleMagnitude(0.25)\n",
    "    regBTo1.SetExpectedSkewMagnitude(0.01)\n",
    "    regBTo1.SetExpectedOffsetMagnitude(40) \n",
    "\n",
    "    regBTo1.SetRigidSamplingRatio(0.1)\n",
    "    regBTo1.SetAffineSamplingRatio(0.1)\n",
    "    \n",
    "    regBTo1.SetSampleFromOverlap(True)\n",
    "    \n",
    "    regBTo1.SetInitialMethodEnum(\"INIT_WITH_IMAGE_CENTERS\")\n",
    "    regBTo1.SetRegistration(\"PIPELINE_AFFINE\")\n",
    "    regBTo1.SetMetric(\"MATTES_MI_METRIC\")\n",
    "    \n",
    "    #regBTo1.SetReportProgress(True)\n",
    "\n",
    "    regBTo1.Update()\n",
    "    \n",
    "    tfm = regBTo1.GetCurrentMatrixTransform()\n",
    "    tfmInv = tfm.GetInverseTransform()\n",
    "    \n",
    "    resm = ttk.ResampleImage.New(Input=imBase[i])\n",
    "    resm.SetMatchImage(im1iso)\n",
    "    resm.SetTransform(tfmInv)\n",
    "    resm.SetLoadTransform(True)\n",
    "    resm.Update()\n",
    "    img = resm.GetOutput()\n",
    "    regB.append( img )\n",
    "\n",
    "    resm = ttk.ResampleImage.New(Input=imBaseB[i])\n",
    "    resm.SetMatchImage(im1iso)\n",
    "    resm.SetTransform(tfmInv)\n",
    "    resm.SetLoadTransform(True)\n",
    "    resm.Update()\n",
    "    img = resm.GetOutput()\n",
    "    regBB.append( img )\n",
    "    \n",
    "    end = time.time()\n",
    "    \n",
    "    percent = (i + 1) / N * 100\n",
    "    print('***  ' + str(i) + ' : ' + str(round(percent)) + '% : ' + str(round(end-start)) + 's  ***')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e6f60a0ba2d54ddea21e6d28dee73ed4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "imMath.SetInput(regB[1])\n",
    "imMath.AddImages(im1iso,20,1)\n",
    "img = imMath.GetOutput()\n",
    "view( img )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "845f58d02dd7433bb522cbae5b84b1f6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "regBBT = []\n",
    "for i in range(0,N):\n",
    "    imMath.SetInput(regBB[i])\n",
    "    imMath.Threshold(0,1,0,1)\n",
    "    img = imMath.GetOutput()\n",
    "    if i==0:\n",
    "        imMath.SetInput( img )\n",
    "        imMath.AddImages( img, 1.0/N, 0 )\n",
    "        sumBBT = imMath.GetOutput()\n",
    "    else:\n",
    "        imMath.SetInput( sumBBT )\n",
    "        imMath.AddImages( img, 1, 1.0/N )\n",
    "        sumBBT = imMath.GetOutput()\n",
    "        \n",
    "view(sumBBT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "imMath.SetInput(sumBBT)\n",
    "imMath.Threshold(0.85,1.1,1,0)\n",
    "imMath.Dilate(5,1,0)\n",
    "imMath.Erode(25,1,0)\n",
    "brainInside = imMath.GetOutput()\n",
    "\n",
    "imMath.SetInput( sumBBT )\n",
    "imMath.Threshold(0,0,1,0)\n",
    "imMath.Erode(1,1,0)\n",
    "brainOutsideAll = imMath.GetOutput()\n",
    "imMath.Erode(20,1,0)\n",
    "imMath.AddImages(brainOutsideAll, -1, 1)\n",
    "brainOutside = imMath.GetOutput()\n",
    "\n",
    "imMath.AddImages(brainInside,1,2)\n",
    "brainCombinedMask = imMath.GetOutputUChar()\n",
    "brainCombinedMaskF = imMath.GetOutput()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2883c7aaea894cf19617c2edcc58ed0b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "imMath.SetInput(brainCombinedMaskF)\n",
    "imMath.AddImages(im1iso, 100, 1)\n",
    "brainCombinedMaskView = imMath.GetOutput()\n",
    "view(brainCombinedMaskView)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "LabelMapType = itk.Image[itk.UC,3]\n",
    "\n",
    "segmenter = ttk.SegmentConnectedComponentsUsingParzenPDFs[ImageType,LabelMapType].New()\n",
    "segmenter.SetFeatureImage( im1iso )\n",
    "segmenter.SetInputLabelMap( brainCombinedMask )\n",
    "segmenter.SetObjectId( 2 )\n",
    "segmenter.AddObjectId( 1 )\n",
    "segmenter.SetVoidId( 0 )\n",
    "segmenter.SetErodeDilateRadius( 10 )\n",
    "segmenter.SetHoleFillIterations( 40 )\n",
    "segmenter.Update()\n",
    "segmenter.ClassifyImages()\n",
    "brainCombinedMaskClassified = segmenter.GetOutputLabelMap()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "83328456ed7944b095c4d1ec429fe679",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageUC3; pr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "view(brainCombinedMaskClassified)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "cast = itk.CastImageFilter[LabelMapType, ImageType].New()\n",
    "cast.SetInput(brainCombinedMaskClassified)\n",
    "cast.Update()\n",
    "brainMaskF = cast.GetOutput()\n",
    "\n",
    "brainMath = ttk.ImageMath[ImageType,ImageType].New(Input = brainMaskF)\n",
    "brainMath.Threshold(2,2,1,0)\n",
    "brainMath.Erode(1,1,0)\n",
    "brainMaskD = brainMath.GetOutput()\n",
    "brainMath.SetInput( im1iso )\n",
    "brainMath.ReplaceValuesOutsideMaskRange( brainMaskD, 1, 1, 0)\n",
    "brain = brainMath.GetOutput()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "65e2d378f0b04194bc46bc76e610b24b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "view(brain)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "writer = itk.ImageFileWriter[ImageType].New(Input = brain)\n",
    "filename = InputBaseName + \"-Brain.mha\"\n",
    "writer.SetFileName(filename)\n",
    "writer.SetUseCompression(True)\n",
    "writer.Update()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
